#!/bin/bash

# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-405b. 
# Please make sure you have run end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh before running commands from this file. 

# The flow of this file is as follows:
# 1. Run a short finetuning job for Llama3.1-405B from the converted scanned checkpoint (see end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh for how a converted checkpoint is produced).
# 2. Run decoding for Llama3.1-405B from the same scanned checkpoint.
# 3. Check that the forward pass logits computed from the unscanned checkpoint match the golden logits.


# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh
# Please note that in these two scripts (1_test_llama3.1_405b.sh and 2_test_llama3.1_405b.sh) BASE_OUTPUT_PATH is assumed to already be a path that is unique across runs, and
# the subfolder names (aka RUN_NAMEs) are static. Please remember to change BASE_OUTPUT_PATH between runs.

# Stop at the first failing command (errexit) and echo each command before it runs (xtrace).
set -o errexit -o xtrace

# forward_pass_logit_checker.py needs torch; install it from the PyTorch CPU wheel index.
python3 -m pip install --index-url https://download.pytorch.org/whl/cpu torch

# Model name handed to the MaxText commands below via `model_name=${MODEL_VARIATION}`.
export MODEL_VARIATION='llama3.1-405b'

# Fall back to a timestamped default when the caller did not provide BASE_OUTPUT_PATH (unset or empty).
if [[ -z "${BASE_OUTPUT_PATH:-}" ]]; then
    # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
    # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh
    # Assign first and export afterwards: `export VAR=$(cmd)` would report the exit status of
    # `export` and silently hide a failing `date`.
    BASE_OUTPUT_PATH="gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)"
    export BASE_OUTPUT_PATH
    echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi


# Non-Googlers: set `DATASET_PATH` to the GCS bucket that holds your training data.
DATASET_PATH='gs://maxtext-dataset'
export DATASET_PATH

# `SCANNED_CHECKPOINT` and `UNSCANNED_CHECKPOINT` refer to the `items` subdirectory of the converted checkpoints. Keeping the full paths in variables makes them easier to reuse in the `MaxText.train`, `MaxText.decode` and logit-checker commands.
export SCANNED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_bf16/scanned/0/items
# Deprecated alias: this script used to export the misspelled name `SCANNED_CHECKCKPOINT`; keep it so anything still reading the old name sees the same path.
export SCANNED_CHECKCKPOINT="${SCANNED_CHECKPOINT}"

export UNSCANNED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_bf16/unscanned/0/items

# We run finetuning by using the scanned converted checkpoint located at `SCANNED_CHECKPOINT`, passed as `load_parameters_path=${SCANNED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning.
# We use a small per_device_batch_size and SGD optimizer for the model to fit on a v4-128. This config is only used for unit testing.
# All expansions are quoted so that paths containing spaces or glob characters reach MaxText as a single argument.
# The tokenizer is looked up in ${MAXTEXT_ASSETS_ROOT} if set, otherwise in the `assets` directory of the MaxText package dir
# (${MAXTEXT_PKG_DIR}, or ${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText when that is unset).
# NOTE(review): this assumes the tokenizer lives in `<package dir>/assets` when MAXTEXT_PKG_DIR is set, mirroring the repo-root default.
export FINETUNE_RUN_NAME=runner_finetune
python3 -m MaxText.train \
    "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
    base_output_directory="${BASE_OUTPUT_PATH}" \
    dataset_type=synthetic \
    tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/assets}/tokenizer_llama3.tiktoken" \
    tokenizer_type=tiktoken \
    load_parameters_path="${SCANNED_CHECKPOINT}" \
    per_device_batch_size=0.25 \
    ici_tensor_parallelism=4 \
    run_name="${FINETUNE_RUN_NAME}" \
    steps=10 \
    enable_checkpointing=false \
    model_name="${MODEL_VARIATION}" \
    logits_dot_in_fp32=false \
    weight_dtype=bfloat16 \
    opt_type=sgd

# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `SCANNED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${SCANNED_CHECKPOINT}`
python3 -m MaxText.decode \
    "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
    tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/assets}/tokenizer_llama3.tiktoken" \
    tokenizer_type=tiktoken \
    load_parameters_path="${SCANNED_CHECKPOINT}" \
    per_device_batch_size=0.0625 \
    ici_tensor_parallelism=4 \
    run_name="runner_$(date +%Y-%m-%d-%H-%M)" \
    max_prefill_predict_length=4 \
    max_target_length=16 \
    dataset_type=synthetic \
    async_checkpointing=false \
    model_name="${MODEL_VARIATION}" \
    attention=dot_product \
    weight_dtype=bfloat16 \
    prompt="I love to"

# We also test whether the forward pass logits match the golden logits for Llama3.1-405B.
# The check loads the unscanned checkpoint (`scan_layers=false`), runs with float32 settings and passes `--max_kl_div=1e-4` to the checker.
# All expansions are quoted so that paths containing spaces or glob characters are passed as a single argument.
# The tokenizer is looked up in ${MAXTEXT_ASSETS_ROOT} if set, otherwise in the `assets` directory of the MaxText package dir
# (${MAXTEXT_PKG_DIR}, or ${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText when that is unset).
# NOTE(review): this assumes the tokenizer lives in `<package dir>/assets` when MAXTEXT_PKG_DIR is set, mirroring the repo-root default.
python3 -m tests.forward_pass_logit_checker \
    "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
    base_output_directory="${BASE_OUTPUT_PATH}" \
    tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/assets}/tokenizer_llama3.tiktoken" \
    tokenizer_type=tiktoken \
    load_parameters_path="${UNSCANNED_CHECKPOINT}" \
    run_name=forward_pass_test \
    per_device_batch_size=0.0625 \
    ici_tensor_parallelism=4 \
    model_name="${MODEL_VARIATION}" \
    max_prefill_predict_length=4 \
    max_target_length=4 \
    dataset_type=synthetic \
    scan_layers=false \
    dtype=float32 \
    activations_in_float32=true \
    matmul_precision=float32 \
    weight_dtype=float32 \
    async_checkpointing=false \
    --max_kl_div=1e-4
